import torch
import numpy as np
import matplotlib.pyplot as plt


class Op:
    """Minimal operator interface: a forward pass plus a gradient pass.

    Calling an instance is shorthand for calling ``forward``. This base
    class implements neither pass; both raise ``NotImplementedError``.
    """

    def __init__(self):
        """Nothing to initialise in the base class."""

    def forward(self, inputs):
        """Compute the outputs from the tensor ``inputs``.

        Input: tensor ``inputs``.
        Output: tensor ``outputs``.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def backward(self, outputs_grads):
        """Propagate gradients from the outputs back to the inputs.

        Input: ``outputs_grads``, the gradient of the final output with
        respect to ``outputs``.
        Output: ``inputs_grads``, the gradient of the final output with
        respect to ``inputs``.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def __call__(self, inputs):
        """Forward the call to ``self.forward(inputs)``."""
        return self.forward(inputs)


class OptimizedFunction3D(Op):
    """Two-variable function f(x) = x0**2 + x1**2 + x1**3 + x0*x1.

    ``forward`` evaluates f and remembers its input in ``self.params['x']``;
    ``backward`` writes the analytic gradient of f at that input into
    ``self.grads['x']``.
    """

    def __init__(self):
        super(OptimizedFunction3D, self).__init__()
        # Both entries hold the placeholder 0 until forward()/backward() run.
        self.params = {'x': 0}
        self.grads = {'x': 0}

    def forward(self, x):
        """Evaluate f at ``x`` and cache ``x`` for a later ``backward``.

        Args:
            x: indexable value whose item 0 is the first variable and item 1
                the second. The expression is element-wise, so each item may
                be a single number or a whole grid of sample points.

        Returns:
            ``x[0]**2 + x[1]**2 + x[1]**3 + x[0]*x[1]``.
        """
        self.params['x'] = x
        return x[0] ** 2 + x[1] ** 2 + x[1] ** 3 + x[0] * x[1]

    def backward(self):
        """Store the gradient of f at the cached input in ``self.grads['x']``.

        ``forward`` must have been called first: the initial placeholder
        ``0`` in ``self.params['x']`` cannot be indexed.

        The result stacks df/dx0 and df/dx1 along a new leading axis, so a
        single point gives a tensor of shape ``(2,)`` and a grid input of
        shape ``(2, H, W)`` gives a gradient of the same shape. Returns
        ``None``; callers read ``self.grads['x']``.
        """
        x = self.params['x']
        grad_x0 = 2 * x[0] + x[1]                    # df/dx0
        grad_x1 = 2 * x[1] + 3 * x[1] ** 2 + x[0]    # df/dx1
        # torch.Tensor([value]) only accepts one-element values, so it fails
        # once forward() has been fed a grid. Converting each partial
        # derivative to a default-dtype, graph-free tensor and stacking them
        # gives the same result for a single point and also handles grids.
        components = [
            torch.as_tensor(g, dtype=torch.get_default_dtype()).detach()
            for g in (grad_x0, grad_x1)
        ]
        self.grads['x'] = torch.stack(components)


# Sample both variables on [-3, 3) with step 0.1 and expand to a 2-D grid.
axis_ticks = np.arange(-3, 3, 0.1)
x1, x2 = np.meshgrid(axis_ticks, axis_ticks)
# Stack the coordinate planes: index 0 holds x1, index 1 holds x2.
init_x = torch.Tensor(np.stack([x1, x2]))

model = OptimizedFunction3D()

# Draw the 3-D surface of the function.
fig = plt.figure()
ax = plt.axes(projection='3d')

X, Y = init_x[0].numpy(), init_x[1].numpy()
# Evaluate the model on the whole grid in one call (changed to
# model(init_x).numpy() by David, 2022-12-04).
Z = model(init_x).numpy()
ax.plot_surface(X, Y, Z, cmap='rainbow')

ax.set_xlabel('x1')
ax.set_ylabel('x2')
ax.set_zlabel('f(x1,x2)')
plt.show()
